$\newcommand{\xv}{\mathbf{x}} \newcommand{\Xv}{\mathbf{X}} \newcommand{\yv}{\mathbf{y}} \newcommand{\zv}{\mathbf{z}} \newcommand{\av}{\mathbf{a}} \newcommand{\Wv}{\mathbf{W}} \newcommand{\wv}{\mathbf{w}} \newcommand{\tv}{\mathbf{t}} \newcommand{\Tv}{\mathbf{T}} \newcommand{\muv}{\boldsymbol{\mu}} \newcommand{\sigmav}{\boldsymbol{\sigma}} \newcommand{\phiv}{\boldsymbol{\phi}} \newcommand{\Phiv}{\boldsymbol{\Phi}} \newcommand{\Sigmav}{\boldsymbol{\Sigma}} \newcommand{\Lambdav}{\boldsymbol{\Lambda}} \newcommand{\half}{\frac{1}{2}} \newcommand{\argmax}[1]{\underset{#1}{\operatorname{argmax}}} \newcommand{\argmin}[1]{\underset{#1}{\operatorname{argmin}}}$
Interpreting What a Neural Network Has Learned
Explainable Artificial Intelligence (XAI): Concepts, taxonomies, opportunities and challenges toward responsible AI, Arrieta, et al., Information Fusion, Volume 58, June 2020, Pages 82-115
"Given a certain audience, explainability refers to the details and reasons a model gives to make its functioning clear or easy to understand."
Here we will examine what the hidden units in a convolutional neural network have learned. This is most intuitive if we focus on classification problems involving images.
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
import os
from A6mysolution import *
# for regression problem
def rmse(a, b):
    """Return the root-mean-square error between `a` and `b`.

    The element-wise difference `a - b` is squared, averaged over every
    element, and the square root of that average is returned.
    """
    residual = a - b
    mean_squared = np.mean(residual ** 2)
    return np.sqrt(mean_squared)
# for classification problem
def percent_correct(a, b):
    """Return the percentage (0 to 100) of elements of `a` that equal `b`.

    The comparison is element-wise (`a == b`); the mean of the resulting
    booleans is scaled by 100.
    """
    matches = a == b
    return 100 * np.mean(matches)
# for classification problem
def confusion_matrix(Y_classes, T, styled=False):
    """Build a confusion matrix whose entries are row-wise percentages.

    Row `r`, column `c` holds the percentage of samples whose true class is
    `r` that were predicted as class `c`. Rows and columns are the unique
    values found in `T`, so a predicted class that never occurs in `T` has
    no column. The overall percent correct is printed as a side effect.

    Parameters
    ----------
    Y_classes : array
        Predicted class per sample. It is indexed with the boolean mask
        `T == true_class`, so it must have the same shape as `T`.
    T : array
        True class per sample.
    styled : bool, optional
        When False (default) the plain DataFrame is returned. When True a
        pandas Styler is returned instead, with a 'Blues' background
        gradient and one-decimal formatting, for display in a notebook.

    Returns
    -------
    pandas.DataFrame, or pandas Styler when `styled` is True.
    """
    class_names = np.unique(T)
    table = [
        [100 * np.mean(Y_classes[T == true_class] == Y_class)
         for Y_class in class_names]
        for true_class in class_names
    ]
    conf_matrix = pd.DataFrame(table, index=class_names, columns=class_names)
    print(f'Percent Correct is {percent_correct(Y_classes, T)}')
    if styled:
        # `DataFrame.style` returns a new Styler on every access, so the
        # styled object has to be returned; styling it and dropping the
        # result leaves `conf_matrix` itself untouched.
        return conf_matrix.style.background_gradient(cmap='Blues').format("{:.1f}")
    return conf_matrix
def makeImages(nEach):
    """Generate `nEach` square outlines followed by `nEach` diamond outlines.

    Every shape is drawn with 1.0-valued pixels on an otherwise zero
    20 x 20 image. Its half-width `r` is drawn at random from 3..7 and its
    centre from r + 1 .. 18 - r on each axis, which keeps the outline off
    the image border.

    Returns
    -------
    images : array of shape (2 * nEach, 1, 20, 20)
        The first `nEach` images are squares, the rest are diamonds.
    T : array of shape (2 * nEach, 1)
        Class labels, 0 for squares and 1 for diamonds.
    """
    n_images = nEach * 2
    images = np.zeros((n_images, 1, 20, 20))  # nSamples, nChannels, rows, columns
    radii = 3 + np.random.randint(10 - 5, size=(n_images, 1))

    for i in range(n_images):
        r = radii[i, 0]
        # One random draw of two offsets per image: (row centre, column centre).
        cx, cy = (r + 1 + np.random.randint(18 - 2 * r, size=(1, 2)))[0]
        canvas = images[i, 0]  # view into `images`, so writes land in place
        top, bottom = cx - r, cx + r
        left, right = cy - r, cy + r

        if i < nEach:
            # squares: two vertical sides, then the top and bottom rows
            canvas[top:bottom, right] = 1.0
            canvas[top:bottom, left] = 1.0
            canvas[top, left:right] = 1.0
            canvas[bottom, left:right + 1] = 1.0
        else:
            # diamonds: four diagonal edges, stepping one pixel at a time
            k = np.arange(r)
            k_closed = np.arange(r + 1)
            canvas[top + k, cy + k] = 1.0
            canvas[top + k, cy - k] = 1.0
            canvas[cx + k_closed, right - k_closed] = 1.0
            canvas[cx + k, left + k] = 1.0

    # Optional additive noise, currently disabled:
    # images += np.random.randn(*images.shape) * 0.5

    T = np.repeat([[0.0], [1.0]], nEach, axis=0)
    return images, T
nEach = 1000

# Training set: each 1 x 20 x 20 image is flattened into one row.
X, T = makeImages(nEach)
X = X.reshape(len(X), -1)
print(X.shape, T.shape)

# Test set drawn from the same generator, flattened the same way.
Xtest, Ttest = makeImages(nEach)
Xtest = Xtest.reshape(len(Xtest), -1)

plt.plot(T);
(2000, 400) (2000, 1)
plt.imshow(-X[-1, :].reshape(20, 20), cmap='gray')
plt.xticks([])
plt.yticks([])
([], [])
# Top row: the first ten images (squares). Bottom row: the last ten (diamonds).
plt.figure(figsize=(10, 3))
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(-X[i, :].reshape(20, 20), cmap='gray')
    plt.xticks([])
    plt.yticks([])
    plt.subplot(2, 10, i + 11)
    # Index with -(i + 1), not -i: X[-0] is X[0], which would put the first
    # square into the diamond row.
    plt.imshow(-X[-(i + 1), :].reshape(20, 20), cmap='gray')
    plt.xticks([])
    plt.yticks([])
nnet, learning_curve = train_for_classification(X, T, hidden_layers=[10],
n_epochs=500, learning_rate=0.01)
plt.plot(learning_curve);
nnet
Sequential( (0): Linear(in_features=400, out_features=10, bias=True) (1): Tanh() (2): Linear(in_features=10, out_features=2, bias=True) (3): LogSoftmax() )
Y = use(nnet, X)
Ytest = use(nnet, Xtest)
Y.shape
(2000, 2)
plt.subplot(2, 1, 1)
plt.plot(Y)
plt.subplot(2, 1, 2)
plt.plot(Ytest)
[<matplotlib.lines.Line2D at 0x7fa89861d370>, <matplotlib.lines.Line2D at 0x7fa89861d460>]
plt.plot(np.exp(Y))
[<matplotlib.lines.Line2D at 0x7fa898582b50>, <matplotlib.lines.Line2D at 0x7fa898582c40>]
Y_classes = np.argmax(Y, axis=1).reshape(-1, 1) # To keep 2-dimensional shape
plt.plot(Y_classes, 'o', label='Predicted')
plt.plot(T + 0.1, 'o', label='Target')
plt.legend();
Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)  # To keep 2-dimensional shape
plt.plot(Y_classes_test, 'o', label='Predicted')
# Test predictions are compared with the test targets (Ttest), not the
# training targets; the 0.1 offset keeps the two marker sets from overlapping.
plt.plot(Ttest + 0.1, 'o', label='Target')
plt.legend();
Y.shape
(2000, 2)
confusion_matrix(Y_classes_test, Ttest)
Percent Correct is 99.05000000000001
0.0 | 1.0 | |
---|---|---|
0.0 | 99.1 | 0.9 |
1.0 | 1.0 | 99.0 |
def forward_all_layers(nnet, X):
    """Run `X` through `nnet` and return the output of every layer.

    Parameters
    ----------
    nnet : iterable of callable layers (for example torch.nn.Sequential)
        Each layer is applied to the output of the previous one.
    X : numpy array
        Input samples; converted to a float32 torch tensor before use.

    Returns
    -------
    list of numpy arrays
        Element 0 is the (float32) input; element k is the output of the
        k-th layer, so the list has one more entry than `nnet` has layers.
    """
    # Only the activations are wanted, so skip building the autograd graph.
    with torch.no_grad():
        outputs = [torch.from_numpy(X).float()]
        for layer in nnet:
            outputs.append(layer(outputs[-1]))
    return [out.detach().numpy() for out in outputs]
Y_square = forward_all_layers(nnet, X[:10, :])
Y_diamond = forward_all_layers(nnet, X[-10:, :])
nnet
Sequential( (0): Linear(in_features=400, out_features=10, bias=True) (1): Tanh() (2): Linear(in_features=10, out_features=2, bias=True) (3): LogSoftmax() )
len(Y_square)
5
Y_square[0].shape
(10, 400)
Y_square[1].shape
(10, 10)
plt.plot(Y_square[1]);
plt.plot(Y_square[2]);
both = np.vstack((Y_square[2], Y_diamond[2]))
plt.plot(both);
plt.figure(figsize=(15, 3))
for unit in range(10):
plt.subplot(1, 10, unit + 1)
plt.plot(both[:, unit])
plt.tight_layout()
plt.plot(both[:, 9])
[<matplotlib.lines.Line2D at 0x7fa898a4fcd0>]
nnet
Sequential( (0): Linear(in_features=400, out_features=10, bias=True) (1): Tanh() (2): Linear(in_features=10, out_features=2, bias=True) (3): LogSoftmax() )
nnet[0].parameters()
<generator object Module.parameters at 0x7fa898a1eac0>
list(nnet[0].parameters())
[Parameter containing: tensor([[-0.0235, 0.0183, -0.0043, ..., -0.0283, -0.0472, 0.0493], [-0.0305, 0.0305, 0.0375, ..., 0.0076, -0.0254, -0.0216], [-0.0205, -0.0357, -0.0086, ..., -0.0464, 0.0402, 0.0307], ..., [-0.0139, 0.0290, 0.0346, ..., -0.0098, 0.0007, 0.0092], [ 0.0226, 0.0153, 0.0103, ..., -0.0146, 0.0159, 0.0084], [ 0.0174, 0.0253, -0.0183, ..., 0.0091, 0.0213, 0.0307]], requires_grad=True), Parameter containing: tensor([-0.8432, -0.7838, -0.5426, -0.7839, 0.6757, -0.6754, -0.7834, 0.7684, 0.8641, 0.5593], requires_grad=True)]
W = list(nnet[0].parameters())[0]
W = W.detach().numpy()
W.shape
(10, 400)
W = W.T
W.shape
(400, 10)
plt.plot(W);
plt.plot(W[:, 0])
plt.imshow(W[:, 0].reshape(20, 20), cmap='RdYlGn')
plt.colorbar()
plt.figure(figsize=(15, 3))
for i in range(10):
plt.subplot(2, 10, i + 1)
plt.imshow(W[:, i].reshape(20,20), cmap='RdYlGn')
plt.xticks([])
plt.yticks([])
plt.colorbar()
plt.subplot(2, 10, i + 11)
plt.plot(both[:, i])
X.shape
plt.imshow(X[4,:].reshape(20, 20), cmap='gray')
Let's automate these steps in a function, so we can try different numbers of hidden units and layers.
nnet
Sequential( (0): Linear(in_features=400, out_features=10, bias=True) (1): Tanh() (2): Linear(in_features=10, out_features=2, bias=True) (3): LogSoftmax() )
Wout = list(nnet[2].parameters())[0]
Wout = Wout.detach().numpy()
Wout = Wout.T
Wout.shape
(10, 2)
def run_again(hiddens):
    """Train a classifier on the shapes data and visualise its first hidden layer.

    Uses the module-level `X`, `T`, `Xtest` and `Ttest`. Plots the learning
    curve, then for each unit of the first hidden layer shows its input
    weights as an image and, below it, the values from index 2 of
    `forward_all_layers` for the first ten and last ten training samples.
    Each lower plot is titled with that unit's two weights taken from
    `nnet[2]`. Finally prints the confusion matrices for the training and
    test sets.

    NOTE(review): treating `nnet[2]` as the output layer matches the
    single-hidden-layer model printed earlier in the notebook; with more
    than one hidden layer it would be a hidden layer instead - confirm
    before relying on the titles.

    Parameters
    ----------
    hiddens : list of int
        Passed to `train_for_classification` as `hidden_layers`;
        `hiddens[0]` is the number of units that are plotted.
    """
    nnet, learning_curve = train_for_classification(X, T, hidden_layers=hiddens,
                                                    n_epochs=1000, learning_rate=0.01)

    plt.figure()
    plt.plot(learning_curve)

    first_ten = forward_all_layers(nnet, X[:10, :])[2]
    last_ten = forward_all_layers(nnet, X[-10:, :])[2]
    both = np.vstack((first_ten, last_ten))

    # Transposed so that column i holds the weights belonging to unit i.
    W = list(nnet[0].parameters())[0].detach().numpy().T
    Wout = list(nnet[2].parameters())[0].detach().numpy().T

    plt.figure(figsize=(15, 3))
    n_units = hiddens[0]
    size = int(np.sqrt(X.shape[1]))
    for i in range(n_units):
        # Upper row: weights of unit i reshaped to the image grid.
        plt.subplot(2, n_units, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])
        # Lower row: unit i's values over the twenty selected samples.
        plt.subplot(2, n_units, i + 1 + n_units)
        plt.plot(both[:, i])
        plt.title(f'{Wout[i,0]:.1f},{Wout[i,1]:.1f}')

    for inputs, targets in ((X, T), (Xtest, Ttest)):
        predicted = np.argmax(use(nnet, inputs), axis=1).reshape(-1, 1)
        print(confusion_matrix(predicted, targets))
run_again([10])
Percent Correct is 100.0 0.0 1.0 0.0 100.0 0.0 1.0 0.0 100.0 Percent Correct is 99.45 0.0 1.0 0.0 99.8 0.2 1.0 0.9 99.1
# Load the small MNIST sample, downloading it with curl on first use.
if os.path.isfile('small_mnist.npz'):
    print('Reading data from \'small_mnist.npz\'.')
else:
    import subprocess
    print('Downloading small_mnist.npz from CS545 site.')
    try:
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the error page; check=True turns that exit status into an exception.
        subprocess.run(['curl', '--fail',
                        'https://www.cs.colostate.edu/~anderson/cs545/notebooks/small_mnist.npz',
                        '-o', 'small_mnist.npz'],
                       check=True)
    except subprocess.CalledProcessError:
        # Remove a partial download so the next run does not try to read it.
        if os.path.isfile('small_mnist.npz'):
            os.remove('small_mnist.npz')
        raise
small_mnist = np.load('small_mnist.npz')

X = small_mnist['X']
T = small_mnist['T']
X.shape, T.shape
Reading data from 'small_mnist.npz'.
((1000, 784), (1000, 1))
plt.imshow(-X[0, :].reshape(28, 28), cmap='gray')
<matplotlib.image.AxesImage at 0x7fa87459ccd0>
Randomly partition the data into 60% for training and 40% for testing, using the following code cells.
# Shuffle the row indices, then give the first 60% of them to the training
# set and the remainder to the test set.
n_samples = len(X)
n_train = int(n_samples * 0.6)

rows = np.arange(n_samples)
np.random.shuffle(rows)

Xtrain, Ttrain = X[rows[:n_train], :], T[rows[:n_train], :]
Xtest, Ttest = X[rows[n_train:], :], T[rows[n_train:], :]
def run_again_mnist(hiddens):
    """Train a classifier on the MNIST split and show its first-layer weights.

    Uses the module-level `Xtrain`, `Ttrain`, `Xtest` and `Ttest`. Plots the
    learning curve, draws the input weights of every unit in the first
    hidden layer as an image on a square grid of subplots, and displays the
    confusion matrices for the training and test sets.

    Parameters
    ----------
    hiddens : list of int
        Passed to `train_for_classification` as `hidden_layers`;
        `hiddens[0]` is the number of first-layer units that are plotted.
    """
    nnet, learning_curve = train_for_classification(Xtrain, Ttrain, hidden_layers=hiddens,
                                                    n_epochs=1000, learning_rate=0.01)
    plt.figure()
    plt.plot(learning_curve)

    # Transposed so that column i holds the input weights of unit i.
    W = list(nnet[0].parameters())[0]
    W = W.detach().numpy()
    W = W.T

    plt.figure(figsize=(15, 15))
    n_units = hiddens[0]
    # Side length of the (square) input images, taken from the training data.
    size = int(np.sqrt(Xtrain.shape[1]))
    # Grid side chosen so that n_rows * n_rows is at least n_units.
    n_rows = int(np.sqrt(n_units) + 1)
    for i in range(n_units):
        plt.subplot(n_rows, n_rows, i + 1)
        plt.imshow(W[:, i].reshape(size, size), cmap='RdYlGn')
        plt.colorbar()
        plt.xticks([])
        plt.yticks([])

    Y = use(nnet, Xtrain)
    Y_classes = np.argmax(Y, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes, Ttrain))

    Ytest = use(nnet, Xtest)
    Y_classes_test = np.argmax(Ytest, axis=1).reshape(-1, 1)
    display(confusion_matrix(Y_classes_test, Ttest))
run_again_mnist([20, 20, 20])
Percent Correct is 100.0
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 100.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
1 | 0.0 | 100.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2 | 0.0 | 0.0 | 100.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
3 | 0.0 | 0.0 | 0.0 | 100.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
4 | 0.0 | 0.0 | 0.0 | 0.0 | 100.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 100.0 | 0.0 | 0.0 | 0.0 | 0.0 |
6 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 100.0 | 0.0 | 0.0 | 0.0 |
7 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 100.0 | 0.0 | 0.0 |
8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 100.0 | 0.0 |
9 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 100.0 |
Percent Correct is 84.25
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 87.234043 | 0.000000 | 0.000000 | 2.127660 | 0.000000 | 10.638298 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
1 | 0.000000 | 93.617021 | 2.127660 | 0.000000 | 0.000000 | 2.127660 | 0.000000 | 0.000000 | 0.000000 | 2.127660 |
2 | 0.000000 | 2.777778 | 91.666667 | 0.000000 | 0.000000 | 0.000000 | 2.777778 | 0.000000 | 2.777778 | 0.000000 |
3 | 0.000000 | 0.000000 | 0.000000 | 79.069767 | 0.000000 | 11.627907 | 0.000000 | 4.651163 | 4.651163 | 0.000000 |
4 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 85.714286 | 0.000000 | 5.714286 | 2.857143 | 0.000000 | 5.714286 |
5 | 4.347826 | 0.000000 | 2.173913 | 4.347826 | 0.000000 | 76.086957 | 0.000000 | 0.000000 | 13.043478 | 0.000000 |
6 | 2.325581 | 0.000000 | 0.000000 | 0.000000 | 2.325581 | 0.000000 | 95.348837 | 0.000000 | 0.000000 | 0.000000 |
7 | 2.777778 | 2.777778 | 2.777778 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 86.111111 | 0.000000 | 5.555556 |
8 | 0.000000 | 3.030303 | 3.030303 | 9.090909 | 0.000000 | 3.030303 | 3.030303 | 9.090909 | 69.696970 | 0.000000 |
9 | 0.000000 | 0.000000 | 2.941176 | 0.000000 | 8.823529 | 2.941176 | 0.000000 | 8.823529 | 2.941176 | 73.529412 |